1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
#include <algorithm>
#include <cstdio>
#include <vector>

// Tree "minimax" expectation via merge-able segment trees.
//
// Input: n; then the parent of every node 1..n (the root is node 1, whose
// parent is expected to be 0); then, for every node, its weight if it is a
// leaf, otherwise p * 10000 where p is the probability the node takes the
// maximum of its children's values (1 - p: the minimum).
// Leaf weights are assumed pairwise distinct (the merge never combines two
// leaves of the same rank).
// Output: sum over ranks r of r * V_r * D_r^2 (mod 998244353), where V_r is
// the r-th smallest leaf weight and D_r the probability the root takes it.

namespace {

constexpr int kMaxN = 300000 + 5;
constexpr int kMod = 998244353;
// Each leaf inserts one root-to-leaf path into an empty tree over [1, cnt];
// with cnt <= 3e5 that path has ceil(log2(3e5)) + 1 = 20 nodes, and merging
// never allocates, so 22 slots per leaf is a safe upper bound.
constexpr int kPoolSize = kMaxN * 22;

// base^exp mod kMod by binary exponentiation (exp >= 0).
int power_mod(long long base, int exp) {
    long long result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return static_cast<int>(result);
}

// Modular inverse of 10000: turns the input's p * 10000 back into p.
const int kInv10000 = power_mod(10000, kMod - 2);

// Node of the dynamic segment tree indexed by leaf-weight rank.
// Index 0 is the shared empty node; no function below ever writes it, so its
// sum stays 0. The pool has static storage (zero-initialised, no constructor
// pass at start-up); a node's tag is set to 1 when it is allocated.
struct Node {
    int ls;   // left child index, 0 if empty
    int rs;   // right child index, 0 if empty
    int sum;  // total probability of the ranks stored in this subtree
    int tag;  // multiplier not yet pushed down to the children
};

Node pool[kPoolSize];
int pool_used = 0;

int root[kMaxN];           // root[u]: rank -> probability tree for node u
int first_child[kMaxN];    // children stored as intrusive sibling lists
int next_sibling[kMaxN];
int leaf_key[kMaxN];       // leaf weight, replaced by its rank after sorting
int percent_max[kMaxN];    // internal node: p * 10000
int value_of_rank[kMaxN];  // inverse of the rank compression

int new_node() {
    ++pool_used;
    pool[pool_used].tag = 1;  // identity multiplier
    return pool_used;
}

// Multiplies every probability in subtree x by k (lazily); no-op for empty x.
void apply_mul(int x, int k) {
    if (!x) return;
    pool[x].sum = static_cast<int>(1LL * pool[x].sum * k % kMod);
    pool[x].tag = static_cast<int>(1LL * pool[x].tag * k % kMod);
}

void push_down(int x) {
    if (pool[x].tag == 1) return;
    apply_mul(pool[x].ls, pool[x].tag);
    apply_mul(pool[x].rs, pool[x].tag);
    pool[x].tag = 1;
}

void pull_up(int x) {
    pool[x].sum = (pool[pool[x].ls].sum + pool[pool[x].rs].sum) % kMod;
}

// Sets the probability at rank `pos` to `val` in the tree rooted at x
// (allocating nodes as needed) and returns the possibly-new root.
int insert(int x, int l, int r, int pos, int val) {
    if (!x) x = new_node();
    if (l == r) {
        pool[x].sum = val;
        return x;
    }
    push_down(x);
    const int mid = (l + r) >> 1;
    if (pos <= mid) {
        pool[x].ls = insert(pool[x].ls, l, mid, pos, val);
    } else {
        pool[x].rs = insert(pool[x].rs, mid + 1, r, pos, val);
    }
    pull_up(x);
    return x;
}

// p * pre + (1 - p) * suf (mod kMod): probability that a value lying
// between the other child's prefix mass `pre` and suffix mass `suf` wins.
int win_weight(int p, int pre, int suf) {
    return static_cast<int>(
        (1LL * p * pre + 1LL * (kMod + 1 - p) * suf) % kMod);
}

// Merges the distributions of two sibling subtrees x and y over [l, r].
// pre_* / suf_* hold each tree's probability mass strictly left / right of
// [l, r]; p is the probability the parent takes the maximum.
int merge(int x, int y, int l, int r,
          int pre_x, int suf_x, int pre_y, int suf_y, int p) {
    if (!y) {  // only x has ranks here: each wins against y's outside mass
        apply_mul(x, win_weight(p, pre_y, suf_y));
        return x;
    }
    if (!x) {
        apply_mul(y, win_weight(p, pre_x, suf_x));
        return y;
    }
    push_down(x);
    push_down(y);
    const int mid = (l + r) >> 1;
    const int x_left = pool[pool[x].ls].sum;
    const int x_right = pool[pool[x].rs].sum;
    const int y_left = pool[pool[y].ls].sum;
    const int y_right = pool[pool[y].rs].sum;
    pool[x].ls = merge(pool[x].ls, pool[y].ls, l, mid,
                       pre_x, (suf_x + x_right) % kMod,
                       pre_y, (suf_y + y_right) % kMod, p);
    pool[x].rs = merge(pool[x].rs, pool[y].rs, mid + 1, r,
                       (pre_x + x_left) % kMod, suf_x,
                       (pre_y + y_left) % kMod, suf_y, p);
    pull_up(x);
    return x;
}

// Adds rank * value * prob^2 for every rank stored under x into ans.
void accumulate_answer(int x, int l, int r, int& ans) {
    if (!x) return;  // empty subtree: every probability is 0
    if (l == r) {
        const long long term = 1LL * l * value_of_rank[l] % kMod
                               * pool[x].sum % kMod * pool[x].sum % kMod;
        ans = static_cast<int>((ans + term) % kMod);
        return;
    }
    push_down(x);
    const int mid = (l + r) >> 1;
    accumulate_answer(pool[x].ls, l, mid, ans);
    accumulate_answer(pool[x].rs, mid + 1, r, ans);
}

void add_child(int parent, int child) {
    next_sibling[child] = first_child[parent];
    first_child[parent] = child;
}

}  // namespace

int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n; ++i) {
        int parent = 0;
        std::scanf("%d", &parent);
        add_child(parent, i);
    }

    std::vector<int> sorted_weights;
    sorted_weights.reserve(n);
    for (int i = 1; i <= n; ++i) {
        if (first_child[i] == 0) {
            std::scanf("%d", &leaf_key[i]);
            sorted_weights.push_back(leaf_key[i]);
        } else {
            std::scanf("%d", &percent_max[i]);
        }
    }
    std::sort(sorted_weights.begin(), sorted_weights.end());
    const int cnt = static_cast<int>(sorted_weights.size());
    for (int i = 1; i <= n; ++i) {
        if (first_child[i] != 0) continue;
        const int rank = static_cast<int>(
            std::lower_bound(sorted_weights.begin(), sorted_weights.end(),
                             leaf_key[i]) - sorted_weights.begin()) + 1;
        value_of_rank[rank] = leaf_key[i];
        leaf_key[i] = rank;
    }

    // Iterative preorder from the root; processing it in reverse guarantees
    // every child is finished before its parent, without deep recursion.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{1};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for (int c = first_child[u]; c; c = next_sibling[c]) pending.push_back(c);
    }

    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (first_child[u] == 0) {
            root[u] = insert(root[u], 1, cnt, leaf_key[u], 1);
            continue;
        }
        const int p = static_cast<int>(1LL * percent_max[u] * kInv10000 % kMod);
        for (int c = first_child[u]; c; c = next_sibling[c]) {
            root[u] = root[u] ? merge(root[u], root[c], 1, cnt, 0, 0, 0, 0, p)
                              : root[c];
        }
    }

    int ans = 0;
    accumulate_answer(root[1], 1, cnt, ans);
    std::printf("%d\n", ans);
    return 0;
}
|